{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets, transforms\n",
    "import torch.utils.data\n",
    "import numpy as np\n",
    "import math\n",
    "from copy import deepcopy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_var(x, requires_grad=False):\n",
    "    \"\"\"Return a detached copy of ``x``, placed on the GPU when CUDA is available.\n",
    "\n",
    "    Args:\n",
    "        x: tensor to copy.\n",
    "        requires_grad: value of the ``requires_grad`` flag set on the copy.\n",
    "\n",
    "    Returns:\n",
    "        A new tensor that carries no autograd history from ``x``.\n",
    "    \"\"\"\n",
    "    source = x.cuda() if torch.cuda.is_available() else x\n",
    "    return source.clone().detach().requires_grad_(requires_grad)\n",
    "\n",
    "\n",
    "class MaskedLinear(nn.Linear):\n",
    "    \"\"\"``nn.Linear`` whose weight can be gated element-wise by a fixed mask.\n",
    "\n",
    "    Until ``set_mask`` is called the layer behaves exactly like ``nn.Linear``.\n",
    "    Afterwards the forward pass multiplies the weight by the mask, so entries\n",
    "    where the mask is zero contribute nothing to the output.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features, out_features, bias=True):\n",
    "        super(MaskedLinear, self).__init__(in_features, out_features, bias)\n",
    "        # Registered as a buffer (not a plain attribute) so that, once it holds a\n",
    "        # tensor, it follows .to()/.cuda() and is written to state_dict().\n",
    "        self.register_buffer('mask', None)\n",
    "\n",
    "    def set_mask(self, mask):\n",
    "        \"\"\"Install ``mask`` and zero the weights it removes.\n",
    "\n",
    "        Args:\n",
    "            mask: tensor multiplied element-wise with the weight; it is copied,\n",
    "                so later changes to the argument do not affect the layer.\n",
    "        \"\"\"\n",
    "        # Copy onto the weight's own device. Moving to CUDA merely because it is\n",
    "        # available breaks layers that were deliberately kept on the CPU.\n",
    "        self.mask = mask.clone().detach().to(self.weight.device)\n",
    "        self.weight.data = self.weight.data * self.mask\n",
    "\n",
    "    def get_mask(self):\n",
    "        \"\"\"Return the installed mask, or ``None`` if no mask has been set.\"\"\"\n",
    "        return self.mask\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply the linear map, using the masked weight when a mask is set.\"\"\"\n",
    "        if self.mask is None:\n",
    "            return F.linear(x, self.weight, self.bias)\n",
    "        return F.linear(x, self.weight * self.mask, self.bias)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    \"\"\"784-200-200-10 fully connected classifier built from MaskedLinear layers.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(MLP, self).__init__()\n",
    "        self.linear1 = MaskedLinear(28 * 28, 200)\n",
    "        self.relu1 = nn.ReLU(inplace=True)\n",
    "        self.linear2 = MaskedLinear(200, 200)\n",
    "        self.relu2 = nn.ReLU(inplace=True)\n",
    "        self.linear3 = MaskedLinear(200, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Flatten each sample and return the 10 unnormalised class scores.\"\"\"\n",
    "        out = x.view(x.size(0), -1)\n",
    "        out = self.relu1(self.linear1(out))\n",
    "        out = self.relu2(self.linear2(out))\n",
    "        return self.linear3(out)\n",
    "\n",
    "    def set_masks(self, masks):\n",
    "        \"\"\"Install ``masks[0]``, ``masks[1]``, ``masks[2]`` on linear1..linear3.\n",
    "\n",
    "        Raises:\n",
    "            IndexError: if ``masks`` holds fewer than three entries.\n",
    "        \"\"\"\n",
    "        layers = (self.linear1, self.linear2, self.linear3)\n",
    "        for index, layer in enumerate(layers):\n",
    "            layer.set_mask(masks[index])\n",
    "\n",
    "    def load_state_dict(self, state_dict, strict=True):\n",
    "        \"\"\"Load ``state_dict``, restoring any pruning masks it contains.\n",
    "\n",
    "        A freshly built model has ``mask`` buffers equal to ``None``; such\n",
    "        buffers are skipped by ``nn.Module.load_state_dict``, so the masks of a\n",
    "        pruned checkpoint would be ignored. They are therefore installed on\n",
    "        the matching layers before the regular load runs.\n",
    "\n",
    "        Args:\n",
    "            state_dict: mapping as produced by ``state_dict()``.\n",
    "            strict: accepted for signature compatibility only; loading is\n",
    "                always non-strict so checkpoints with and without masks load.\n",
    "\n",
    "        Returns:\n",
    "            The missing/unexpected-keys result of ``nn.Module.load_state_dict``.\n",
    "        \"\"\"\n",
    "        for name, module in self.named_modules():\n",
    "            if not isinstance(module, MaskedLinear):\n",
    "                continue\n",
    "            saved_mask = state_dict.get(name + '.mask')\n",
    "            if saved_mask is not None:\n",
    "                # Assign the buffer directly, on the layer's own device.\n",
    "                module.mask = saved_mask.clone().detach().to(module.weight.device)\n",
    "        return super().load_state_dict(state_dict, strict=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, device, train_loader, optimizer, epoch):\n",
    "    \"\"\"Run one optimisation pass over ``train_loader`` with a text progress bar.\n",
    "\n",
    "    Args:\n",
    "        model: module called on each batch; put into training mode first.\n",
    "        device: device every batch is moved to before the forward pass.\n",
    "        train_loader: iterable of ``(data, target)`` batches that also provides\n",
    "            ``len()`` and a ``dataset`` attribute (used for the progress line).\n",
    "        optimizer: optimizer stepped once per batch.\n",
    "        epoch: epoch number, only used in the progress line.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    num_batches = len(train_loader)\n",
    "    seen = 0\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        loss = F.cross_entropy(model(data), target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        seen += len(data)\n",
    "        # batch_idx is zero-based: count the batch that just finished so the bar\n",
    "        # reaches 50 ticks (100%) on the last batch, whatever the loader length.\n",
    "        progress = math.ceil((batch_idx + 1) / num_batches * 50)\n",
    "        print(\"\\rTrain epoch %d: %d/%d, [%-51s] %d%%\" %\n",
    "              (epoch, seen, len(train_loader.dataset),\n",
    "               '-' * progress + '>', progress * 2), end='')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(model, device, test_loader):\n",
    "    \"\"\"Evaluate ``model`` on ``test_loader`` and print a one-line summary.\n",
    "\n",
    "    Args:\n",
    "        model: module to evaluate; switched to eval mode.\n",
    "        device: device every batch is moved to.\n",
    "        test_loader: iterable of ``(data, target)`` batches with a ``dataset``\n",
    "            attribute whose length is the number of evaluated samples.\n",
    "\n",
    "    Returns:\n",
    "        ``(mean cross-entropy per sample, fraction of correct predictions)``.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    loss_sum, num_correct = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data = data.to(device)\n",
    "            target = target.to(device)\n",
    "            scores = model(data)\n",
    "            # Summed (not averaged) per batch; the division happens once at the end.\n",
    "            loss_sum += F.cross_entropy(scores, target, reduction='sum').item()\n",
    "            predicted = scores.argmax(dim=1, keepdim=True)  # index of the highest score\n",
    "            num_correct += predicted.eq(target.view_as(predicted)).sum().item()\n",
    "\n",
    "    num_samples = len(test_loader.dataset)\n",
    "    mean_loss = loss_sum / num_samples\n",
    "    accuracy = num_correct / num_samples\n",
    "\n",
    "    print('\\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(\n",
    "        mean_loss, num_correct, num_samples, 100. * num_correct / num_samples))\n",
    "    return mean_loss, accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def weight_prune(model, pruning_perc):\n",
    "    '''\n",
    "    Build one keep-mask per weight tensor, pruning pruning_perc % layer-wise.\n",
    "\n",
    "    For every parameter that is not 1-D, the pruning_perc-th percentile of\n",
    "    its absolute values is used as a threshold; the mask is 1.0 where the\n",
    "    magnitude is strictly above that threshold and 0.0 elsewhere.\n",
    "\n",
    "    Returns a list of float masks in model.parameters() order.\n",
    "    '''\n",
    "    masks = []\n",
    "    for param in model.parameters():\n",
    "        if param.dim() == 1:  # 1-D parameters are treated as biases and skipped\n",
    "            continue\n",
    "        magnitudes = param.cpu().data.abs().numpy().flatten()\n",
    "        cutoff = np.percentile(magnitudes, pruning_perc)\n",
    "        keep = param.data.abs() > cutoff\n",
    "        masks.append(keep.float())\n",
    "    return masks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(epochs=1, batch_size=64, pruning_perc=60, data_dir='../data/MNIST', seed=0):\n",
    "    \"\"\"Train an MLP on MNIST, then build and evaluate a magnitude-pruned copy.\n",
    "\n",
    "    Both state dicts are written to the working directory as\n",
    "    'not-pruned.ckpt' and 'pruned.ckpt'.\n",
    "\n",
    "    Args:\n",
    "        epochs: number of training epochs.\n",
    "        batch_size: training batch size (evaluation always uses 1000).\n",
    "        pruning_perc: percentage of weights removed in every layer.\n",
    "        data_dir: directory holding MNIST; it must already exist because\n",
    "            ``download=False`` is used.\n",
    "        seed: value passed to ``torch.manual_seed``.\n",
    "\n",
    "    Returns:\n",
    "        ``(model, pruned_model)``: the trained dense model and its pruned copy.\n",
    "    \"\"\"\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    # Same preprocessing for both splits.\n",
    "    # NOTE(review): (0.1307, 0.3081) are presumably the MNIST mean/std - confirm.\n",
    "    transform = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.1307,), (0.3081,))\n",
    "    ])\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST(data_dir, train=True, download=False, transform=transform),\n",
    "        batch_size=batch_size, shuffle=True)\n",
    "    # shuffle=True is kept for the test split so the random-number stream, and\n",
    "    # with it multi-epoch training, stays identical to earlier runs.\n",
    "    test_loader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST(data_dir, train=False, download=False, transform=transform),\n",
    "        batch_size=1000, shuffle=True)\n",
    "\n",
    "    model = MLP().to(device)\n",
    "    optimizer = torch.optim.Adadelta(model.parameters())\n",
    "\n",
    "    for epoch in range(1, epochs + 1):\n",
    "        train(model, device, train_loader, optimizer, epoch)\n",
    "        test(model, device, test_loader)\n",
    "\n",
    "    torch.save(model.state_dict(), 'not-pruned.ckpt')\n",
    "    print(\"\\n=====Pruning %g%%=======\\n\" % pruning_perc)\n",
    "    # Prune a deep copy so the dense model is returned untouched.\n",
    "    pruned_model = deepcopy(model)\n",
    "    mask = weight_prune(pruned_model, pruning_perc)\n",
    "    pruned_model.set_masks(mask)\n",
    "    test(pruned_model, device, test_loader)\n",
    "    torch.save(pruned_model.state_dict(), 'pruned.ckpt')\n",
    "\n",
    "    return model, pruned_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the dense model and build its pruned copy; both are reused by the cells below.\n",
    "model, pruned_model = main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), 'not-pruned.ckpt')\n",
    "\n",
    "print(model.state_dict()) # every layer's mask is None, so no mask entries are printed\n",
    "\n",
    "torch.save(pruned_model.state_dict(), 'pruned.ckpt')\n",
    "print(pruned_model.state_dict()) # every layer's mask is a FloatTensor\n",
    "\n",
    "# Compare the file sizes of 'not-pruned.ckpt' and 'pruned.ckpt': 'pruned.ckpt' is roughly twice as large as 'not-pruned.ckpt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Walk through the state_dict once to see which entries it contains\n",
    "sd = pruned_model.state_dict()\n",
    "for k, v in sd.items():\n",
    "    print(k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rewrite the weight and mask entries of pruned_model.state_dict(), then save the result.\n",
    "# A fresh state_dict is taken here so the cell can be re-run without compacting twice,\n",
    "# and every masked layer (not only the first two) is processed.\n",
    "compact_sd = pruned_model.state_dict()\n",
    "for mask_key in [key for key in compact_sd if key.endswith('.mask')]:\n",
    "    weight_key = mask_key[:-len('mask')] + 'weight'\n",
    "\n",
    "    # FloatTensor -> BoolTensor: torch.bool takes 1 byte per element, 4x less than float32\n",
    "    keep = compact_sd[mask_key] > 0\n",
    "    compact_sd[weight_key] = compact_sd[weight_key][keep]  # store only the surviving values\n",
    "    compact_sd[mask_key] = keep\n",
    "\n",
    "    # Note: if a mask is almost entirely zero, it could additionally be sparse-encoded\n",
    "\n",
    "# Now save the compacted state dict\n",
    "torch.save(compact_sd, 'modified-pruned.ckpt')\n",
    "\n",
    "# Then check the size of 'modified-pruned.ckpt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: before calling load_state_dict the dense weight has to be rebuilt, along the lines of\n",
    "# weight = torch.zeros_like(mask.float())\n",
    "# weight[mask] = non_zero_weight\n",
    "# mask = mask.float()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 可视化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_weights(model):\n",
    "    \"\"\"Show one histogram of non-zero weights per weighted layer of ``model``.\n",
    "\n",
    "    Every sub-module that has a ``weight`` which is not ``None`` gets its own\n",
    "    panel, titled with the module's name. Nothing is drawn when the model\n",
    "    has no such layer.\n",
    "\n",
    "    Args:\n",
    "        model: module whose sub-modules are inspected via ``named_modules()``.\n",
    "    \"\"\"\n",
    "    weighted = [(name, layer) for name, layer in model.named_modules()\n",
    "                if getattr(layer, 'weight', None) is not None]\n",
    "    if not weighted:\n",
    "        return\n",
    "\n",
    "    # One column per layer; a fixed 1x3 grid raised as soon as a fourth\n",
    "    # weighted layer was present.\n",
    "    fig, axes = plt.subplots(1, len(weighted), squeeze=False)\n",
    "    for ax, (name, layer) in zip(axes[0], weighted):\n",
    "        values = layer.weight.data.cpu().numpy().flatten()\n",
    "        # Exact zeros (such as pruned weights) are left out of the histogram.\n",
    "        ax.hist(values[values != 0], bins=50)\n",
    "        ax.set_title(name)\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histograms of the non-zero weights of the trained, unpruned model.\n",
    "plot_weights(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same plot for the pruned copy; weights that are exactly zero are excluded.\n",
    "plot_weights(pruned_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
